
Backport multi-format tool parser to v0.12.0#2

Merged
hanseungwook merged 4 commits into v0.12.0-ifm from tgat/multi-format-tool-parser-v0.12.0
Apr 22, 2026

Conversation

@hanseungwook
Collaborator

hanseungwook commented Apr 10, 2026

Summary

  • backport the multi-format tool parser to the v0.12.0 serving stack
  • add multi_format tool parser support for default, qwen3, glm, minimax, dsv32, gptoss, and python chat-template formats
  • thread chat_template_kwargs into compatible tool parsers without breaking legacy parser constructors
  • align glm and python parsing with the exact chat-template output format
  • add targeted tests for parser dispatch/extraction and the constructor compatibility shim
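
The constructor compatibility shim mentioned above can be sketched roughly like this — a hypothetical illustration (the function and class names are illustrative, not the actual vLLM API), assuming the shim inspects each parser's `__init__` signature and only forwards `chat_template_kwargs` to parsers that accept it:

```python
import inspect


def construct_tool_parser(parser_cls, tokenizer, chat_template_kwargs=None):
    """Forward chat_template_kwargs only to parsers whose __init__ accepts
    that keyword, so legacy parser constructors keep working unchanged.
    Hypothetical sketch of the compatibility shim described above."""
    params = inspect.signature(parser_cls.__init__).parameters
    if chat_template_kwargs is not None and "chat_template_kwargs" in params:
        return parser_cls(tokenizer, chat_template_kwargs=chat_template_kwargs)
    return parser_cls(tokenizer)
```

A legacy parser taking only `tokenizer` and a newer parser taking `chat_template_kwargs` can then both be constructed through the same call site.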

Testing

  • .venv/bin/python -m ruff check vllm/entrypoints/openai/serving_chat.py vllm/entrypoints/openai/serving_engine.py vllm/entrypoints/openai/serving_responses.py vllm/entrypoints/openai/tool_parsers/__init__.py vllm/entrypoints/openai/tool_parsers/multi_format_tool_parser.py tests/entrypoints/openai/test_tool_parser_kwargs.py tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py
  • .venv/bin/python -m pytest tests/entrypoints/openai/test_tool_parser_kwargs.py tests/entrypoints/openai/tool_parsers/test_multi_format_tool_parser.py -v
  • Result: 20 passed.

@hanseungwook
Collaborator (Author)

Streaming mode is not supported at the moment.

hanseungwook force-pushed the tgat/multi-format-tool-parser-v0.12.0 branch from c619f38 to 677cef0 on April 10, 2026 21:29
- Add xllm.py with GroupRMSNorm (per-group variance, layernorm_num_groups)
- Register XllmForCausalLM in the model registry
- Supports dense xllm checkpoints (Qwen3MoE-based architecture, no MoE)
- Tested with 8B checkpoint on 8x H200 (TP=8)

Co-Authored-By: claude-flow <ruv@ruv.net>
@shauryr
Collaborator

shauryr commented Apr 20, 2026

Testing the XllmForCausalLM support

This PR (with the xllm commit) was verified end-to-end against a real xllm 8B checkpoint.

1. Install

git clone --branch tgat/multi-format-tool-parser-v0.12.0 https://github.com/LLM360/vllm.git
cd vllm
uv venv --python 3.12 .venv
VLLM_USE_PRECOMPILED=1 uv pip install --editable .

2. Serve a checkpoint with model_type: "xllm"

.venv/bin/vllm serve /path/to/xllm/checkpoint \
    --trust-remote-code \
    --tensor-parallel-size 8 \
    --port 8000 \
    --reasoning-parser k2_v3 \
    --enable-auto-tool-choice \
    --tool-call-parser multi_format

You should see in the logs:

INFO [model.py:637] Resolved architecture: XllmForCausalLM
INFO [gpu_model_runner.py:3549] Model loading took 2.1271 GiB memory and 46.158572 seconds
INFO: Application startup complete.

3. Sanity check — simple chat

import requests
resp = requests.post("http://localhost:8000/v1/chat/completions", json={
    "model": "/path/to/xllm/checkpoint",
    "messages": [{"role": "user", "content": "What is 15 * 23?"}],
    "max_tokens": 256,
    "temperature": 0.7,
    "top_p": 0.95,
    "extra_body": {"chat_template_kwargs": {"reasoning_effort": "high"}},
})
print(resp.json()["choices"][0]["message"])

Expected: a valid response with reasoning_content and content populated (e.g. "15 × 23 = 345.").

4. Verify reasoning + tool parsers

Reasoning effort ↔ think token mapping:

  • reasoning_effort=high → <think>
  • reasoning_effort=medium → <think_fast>
  • reasoning_effort=low → <think_faster>

The reasoning_content field should be populated on the response when the model emits a think block.
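The split described above can be sketched as a minimal extractor — illustrative only (the real k2_v3 reasoning parser in vLLM is more involved), assuming the tag names follow the mapping listed above:

```python
import re

# Think-tag name per reasoning effort (assumed mapping, per the list above).
THINK_TAGS = {"high": "think", "medium": "think_fast", "low": "think_faster"}


def split_reasoning(text, effort="high"):
    """Split raw model output into (reasoning_content, content) around the
    effort-specific think block. Returns (None, text) if no block is found."""
    tag = THINK_TAGS[effort]
    m = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    if not m:
        return None, text
    return m.group(1).strip(), text[m.end():].strip()
```

For example, `split_reasoning("<think>compute 15*23</think>345", "high")` yields `("compute 15*23", "345")`.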

Tool calling (default format):

resp = requests.post("http://localhost:8000/v1/chat/completions", json={
    "model": "/path/to/xllm/checkpoint",
    "messages": [{"role": "user", "content": "What's the weather in Tokyo?"}],
    "tools": [{"type": "function", "function": {
        "name": "get_weather",
        "description": "Get current weather",
        "parameters": {"type": "object",
                       "properties": {"location": {"type": "string"}},
                       "required": ["location"]}
    }}],
    "tool_choice": "auto",
})
# Expected: choices[0].message.tool_calls[0].function.name == "get_weather"
#           arguments parses to {"location": "Tokyo"}

5. End-to-end Harbor task (optional)

export DAYTONA_API_KEY=<key>
export OPENAI_API_KEY=dummy
harbor run \
  -p /path/to/LiteCoder/build-go-hello-binary \
  -e daytona -a terminus-2 \
  -m "openai//path/to/xllm/checkpoint" \
  --ak api_base=http://localhost:8000/v1 \
  --ak reasoning_effort=medium \
  -n 1

Expected: reward = 1.0 — model clones a Go repo, builds the binary, writes output to /app/output.txt.

What the changes actually do

  • vllm/model_executor/models/xllm.py: new model implementation. Structurally identical to Llama (same weight names) but replaces RMSNorm with GroupRMSNorm, which computes variance per group of hidden_size/n_groups dims instead of the full hidden dimension. Required because xllm checkpoints are trained with layernorm_num_groups=4 — using standard RMSNorm would produce numerically wrong outputs.
  • vllm/model_executor/models/registry.py: one-line addition to map "XllmForCausalLM" → the xllm module.

Without these changes, vLLM fails with "unsupported architecture: XllmForCausalLM" at startup.
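The per-group variance idea can be sketched in a few lines of NumPy — a minimal illustration of the technique described above, not the actual vLLM kernel, assuming the usual RMSNorm convention (root mean square of activations, no mean subtraction) applied per group:

```python
import numpy as np


def group_rms_norm(x, weight, n_groups=4, eps=1e-6):
    """RMS-normalize each group of hidden_size / n_groups dims separately,
    instead of normalizing over the full hidden dimension.
    Illustrative sketch of per-group RMSNorm; not the vLLM implementation."""
    hidden = x.shape[-1]
    assert hidden % n_groups == 0, "hidden size must divide evenly into groups"
    g = x.reshape(*x.shape[:-1], n_groups, hidden // n_groups)
    rms = np.sqrt(np.mean(g ** 2, axis=-1, keepdims=True) + eps)
    return (g / rms).reshape(x.shape) * weight
```

With n_groups=1 this reduces to standard RMSNorm, which is why loading such a checkpoint through the standard RMSNorm path produces numerically wrong activations whenever n_groups > 1.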

@shauryr
Collaborator

shauryr commented Apr 20, 2026

✅ Test results on xllm 8B checkpoint

Ran the test plan from the previous comment against a real xllm 8B checkpoint (config.json architecture = XllmForCausalLM, layernorm_num_groups=4). Server brought up with:

.venv/bin/vllm serve /path/to/xllm/checkpoint \
    --trust-remote-code \
    --tensor-parallel-size 8 \
    --port 8000 \
    --reasoning-parser k2_v3 \
    --enable-auto-tool-choice \
    --tool-call-parser multi_format

Logs confirm architecture resolution:

INFO [model.py:637] Resolved architecture: XllmForCausalLM
INFO [gpu_model_runner.py:3549] Model loading took 2.1271 GiB memory and 46.158572 seconds
INFO: Application startup complete.

Parser + model test suite: 14/15 passed

#  Test                                Result    Notes
1  Reasoning (high) — <think>          ✅ PASS   "15 × 23 = 345." + populated reasoning_content
2  Reasoning (medium) — <think_fast>   ✅ PASS   Correct capital of Japan, 967 chars of reasoning
3  Reasoning (low) — <think_faster>    ✅ PASS   "Hello! How can I assist you today?"
4  Tool calling — default format       ✅ PASS   Parsed get_weather(location="Tokyo"); arguments valid JSON
5  Tool calling — gptoss format        ✅ PASS   Parsed search tool call correctly
6  Plain chat (no tools)               ❌ FAIL   Empty content — checkpoint ran out of tokens in reasoning; not a parser/model-loader bug
7  Multi-turn conversation             ✅ PASS   Correctly computed 56 ÷ 4 = 14

Correct math answers are the key signal that GroupRMSNorm (per-group variance with n_groups=4) is implemented correctly — standard RMSNorm would have produced garbled outputs.

What's verified by this run

  1. XllmForCausalLM resolves via the registry addition.
  2. Weight loading works — same weight names as Llama, 36 shards loaded, 2.1 GiB/GPU across TP=8.
  3. GroupRMSNorm computes correctly (numeric answers are right, reasoning is coherent).
  4. Both the k2_v3 reasoning parser and multi_format tool parser work correctly on top of the new model.

End-to-end sanity: Harbor task

Also ran a full Harbor LiteCoder/build-go-hello-binary trial against this checkpoint with the terminus-2 agent pointed at the local vLLM server (api_base=http://localhost:8000/v1). The model successfully cloned the Go repo, built the binary, and wrote the output file — reward = 1.0 in a ~1m44s run. So the changes hold up under a real multi-turn agent loop, not just single-shot requests.

@hanseungwook hanseungwook merged commit ce42a74 into v0.12.0-ifm Apr 22, 2026